Skip to content

Add nnx.prefix#5325

Open
samanklesaria wants to merge 2 commits intogoogle:mainfrom
samanklesaria:prefix
Open

Add nnx.prefix#5325
samanklesaria wants to merge 2 commits intogoogle:mainfrom
samanklesaria:prefix

Conversation

@samanklesaria
Copy link
Copy Markdown
Collaborator

@samanklesaria samanklesaria commented Mar 12, 2026

Followup to #5270. Adds nnx.prefix, which can be used to provide 'in_axes' specifications based on a filters. Note that this depends on a separate PR for temporary configuration changes.

@gemini-code-assist
Copy link
Copy Markdown
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new nnx.prefix utility to facilitate filter-based in_axes specifications for JAX transformations, enhancing flexibility in how models are processed. Concurrently, it refactors the random number generation (RNG) key splitting mechanism within Flax NNX, deprecating the split argument in fork methods and introducing dedicated split methods for clearer and more consistent API usage. These changes streamline the handling of model transformations and RNG management.

Highlights

  • New nnx.prefix function: Introduced nnx.prefix in flax.nnx.pytreelib which allows for creating in_axes specifications for JAX transformations (like vmap) based on filters, enabling more granular control over how different parts of a Pytree are mapped.
  • RNG Splitting Refactoring: The functionality for splitting RNG keys has been refactored. A new split method was added to RngStream and Rngs classes, providing a dedicated way to split RNG keys. The split argument within the fork methods of RngStream and Rngs has been deprecated in favor of these new split methods.
  • Config Flag Flexibility: The temp_flip_flag context manager in flax.configurations now accepts an optional prefix argument, allowing it to temporarily modify configuration flags with custom prefixes beyond just 'flax_'.
  • Documentation and Examples Updated: Documentation and examples for Rngs.fork and nnx.fork_rngs have been updated to reflect the removal of the split argument and to guide users towards the new split methods.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • flax/configurations.py
    • Modified temp_flip_flag to accept an optional prefix argument, allowing for more flexible configuration flag manipulation.
    • Added a newline at the end of the file.
  • flax/nnx/init.py
    • Imported the new prefix function from flax.nnx.pytreelib.
  • flax/nnx/pytreelib.py
    • Imported filterlib to support filter-based operations.
    • Added the prefix function, which maps a pytree with a filter map to apply prefixes based on predicates.
  • flax/nnx/rnglib.py
    • Imported the warnings module for deprecation notices.
    • Added a split method to RngStream for splitting RNG keys.
    • Deprecated the split argument in RngStream.fork with a warning, advising users to use the new split method.
    • Introduced a new split method in Rngs to handle splitting RNG keys based on filters, replacing previous fork functionality.
    • Deprecated the split argument in Rngs.fork with a warning, directing users to the new split method.
    • Removed the split argument from the fork_rngs function signature and updated its documentation and examples to reflect this change.
  • tests/nnx/prefix_test.py
    • Added a new test file to verify the functionality of the nnx.prefix utility, specifically demonstrating its use with jax.vmap and Rngs.split.
  • tests/nnx/rngs_test.py
    • Updated test_fork_rngs to remove the split argument from the nnx.fork_rngs call and verify that the key changes after forking.
  • tests/nnx/transforms_test.py
    • Updated an example in test_vmap_rngs_module to use rngs.split(5) instead of rngs.fork(split=5).
Activity
  • The pull request introduces new functionality and refactors existing code, indicating initial development and preparation for review. No human activity (comments, reviews) has been recorded yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request refactors the RNG splitting mechanism in Flax NNX by introducing a new Rngs.split method to replace the split argument in Rngs.fork and nnx.fork_rngs, which are now deprecated. A new prefix utility function is added to pytreelib for applying prefixes to pytrees based on filters, and the temp_flip_flag context manager is enhanced with a prefix argument for more flexible configuration. Review comments highlight a potential bug in the new Rngs.split method where RngStream objects might be unintentionally shared instead of being properly forked, and suggest strengthening a test case for fork_rngs to ensure correct restoration of RNG states.

@samanklesaria samanklesaria force-pushed the prefix branch 2 times, most recently from 2ab65b7 to 0a684c5 Compare March 16, 2026 15:33
@samanklesaria
Copy link
Copy Markdown
Collaborator Author

samanklesaria commented Mar 16, 2026

Interestingly, the following works:

      def my_fn(rngs):
        model = Model(rngs)
        return (model, rngs)
      prefix = nnx.prefix(rngs, {'dropout': 0})
      model, new_rngs = jax.vmap(my_fn, out_axes=(0, prefix), in_axes=(prefix,))(rngs)

So why does nnx.vmap, which is basically doing the same thing in tree mode, not work?

As far as I can tell, this is because the updates returned by nnx.vmap don't have exactly the same pytree shape as the inputs (and therefore don't match the pytree prefixes of the outputs). Specifically, there are Nones wherever the inputs have been left unchanged. And jax can't handle this discrepancy.

@samanklesaria
Copy link
Copy Markdown
Collaborator Author

Found a fix. The issue is that when we mask variables that haven't changed, we need the prefix for these positions to be None. So when constructing the prefix, we just need to ensure that we treat nnx.Variable nodes as leaves rather than recursing into them.

@samanklesaria samanklesaria marked this pull request as ready for review March 17, 2026 20:09
@samanklesaria samanklesaria force-pushed the prefix branch 3 times, most recently from fd9674e to 99e12da Compare March 20, 2026 15:07
return obj
return None

return jax.tree.map_with_path(lookup, pytree,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recently added nnx.map which can be used here, else you have to convert from the jax path format to the nnx path format

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would require that the prefix filter only select Variable nodes. Which is probably what we usually want, but a little less flexible than using jax.tree.map_with_path as I have here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get a strange error if I try to switch to nnx.map:

E     At that key path, the prefix pytree vmap out_axes has a subtree of type
E         <class 'flax.nnx.rnglib.RngKey'>
E     but at the same key path the full pytree has a subtree of different type
E         <class 'flax.nnx.extract.Mask'>.

What's nnx.extract.Mask? I guess I have some digging to do.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I'm just converting the jax path format to the nnx one, which makes the tests pass. But I'll investigate why nnx.map produces different behavior.

@samanklesaria
Copy link
Copy Markdown
Collaborator Author

Added an additional test which is failing. Will debug.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants